# This is the code for generating the communication complexity plots for varying second-order heterogeneity

import numpy as np
import matplotlib.pyplot as plt

def sample_spherical_cap(axis, theta, n_samples):
    """
    Sample n_samples unit vectors within a spherical cap of half-angle theta around 'axis'.

    Candidates are drawn one at a time as normalized standard Gaussian vectors
    (i.e. uniformly on the unit sphere) from the global ``np.random`` state and
    kept only when they fall inside the cap (rejection sampling). The expected
    number of draws therefore grows quickly as theta shrinks or as the
    dimension grows.

    Parameters
    ----------
    axis : array-like, shape (d,)
        Direction of the cap's center. It is normalized internally, so only
        its direction matters; its norm must be nonzero and finite.
    theta : float
        Half-angle of the cap in radians. Values below 1e-6 are treated as a
        degenerate cap: every returned row equals the normalized axis and no
        random numbers are drawn.
    n_samples : int
        Number of vectors to return.

    Returns
    -------
    numpy.ndarray, shape (n_samples, d)
        One accepted unit vector per row.

    Raises
    ------
    ValueError
        If axis has zero or non-finite norm, or theta is NaN. In both cases
        the acceptance test could never succeed and the loop would not end.
    """
    axis = np.asarray(axis, dtype=float)
    d = axis.shape[0]

    axis_norm = np.linalg.norm(axis)
    if not np.isfinite(axis_norm) or axis_norm == 0.0:
        raise ValueError("axis must have a nonzero, finite norm")
    if np.isnan(theta):
        raise ValueError("theta must not be NaN")

    axis = axis / axis_norm
    if theta < 1e-6:
        return np.tile(axis, (n_samples, 1))

    cos_theta = np.cos(theta)
    # Preallocate so the result is always 2-D, even when n_samples == 0.
    samples = np.empty((max(n_samples, 0), d))
    accepted = 0
    while accepted < n_samples:
        # Draw candidates one by one (not in batches) so the sequence of
        # random numbers consumed is the same as a plain rejection loop.
        v = np.random.normal(size=d)
        v /= np.linalg.norm(v)
        # v is inside the cap iff its angle to the axis is at most theta.
        if np.dot(v, axis) >= cos_theta:
            samples[accepted] = v
            accepted += 1
    return samples

def rounds_to_target(mu_list, x_star_list, sigma_noise, M, K, R_max, step_sizes, target_error):
    """
    For each step-size, run Local SGD up to R_max rounds and record the first round
    where the error <= target_error. Return the minimum such round across step-sizes.

    In every round each client starts from the shared iterate, takes K
    single-sample SGD steps on a least-squares loss, and the client iterates
    are then averaged. Client i draws covariates from N(mu_list[i], I) and
    labels ``beta . x_star_list[i] + N(0, sigma_noise^2)``. The error is the
    Euclidean distance between the averaged iterate and the mean of
    x_star_list. All randomness comes from the global ``np.random`` state.

    Parameters
    ----------
    mu_list : sequence of arrays, each shape (d,)
        Per-client covariate means.
    x_star_list : sequence of arrays, each shape (d,)
        Per-client ground-truth parameters; must have the same length as mu_list.
    sigma_noise : float
        Standard deviation of the additive label noise.
    M : int
        Not used by this function (the number of clients is taken from
        mu_list); kept so the call signature stays unchanged.
    K : int
        Local SGD steps per client per round.
    R_max : int
        Maximum number of communication rounds per step-size.
    step_sizes : iterable of float
        Step-sizes to try; each one starts again from the zero vector.
    target_error : float
        Error threshold that counts as reaching the target.

    Returns
    -------
    int
        Smallest round index (1-based) at which any step-size reached the
        target, or R_max if none did. A return value of R_max therefore does
        not distinguish "reached exactly at R_max" from "never reached".

    Raises
    ------
    ValueError
        If mu_list and x_star_list differ in length.

    Notes
    -----
    A run is abandoned as soon as its iterate contains inf/NaN (it cannot
    become finite again), and later step-sizes are only simulated for rounds
    that could still beat the best round found so far. This avoids wasted
    sampling but means the number of random draws consumed depends on the
    outcome.
    """
    if len(mu_list) != len(x_star_list):
        # zip() would silently drop the extra clients while the error target
        # below would still average over every entry of x_star_list.
        raise ValueError("mu_list and x_star_list must have the same length")

    d = mu_list[0].shape[0]
    identity_cov = np.eye(d)                    # loop-invariant covariance
    x_target = np.mean(x_star_list, axis=0)     # loop-invariant error target

    best_round = R_max
    round_limit = R_max  # last round that can still improve on best_round
    for eta in step_sizes:
        if round_limit < 1:
            # Target already reached in round 1; nothing can do better.
            break
        x = np.zeros(d)
        for r in range(1, round_limit + 1):
            x_locals = []
            for mu, x_star in zip(mu_list, x_star_list):
                x_local = x.copy()
                for _ in range(K):
                    beta = np.random.multivariate_normal(mu, identity_cov)
                    y = beta.dot(x_star) + np.random.normal(scale=sigma_noise)
                    grad = (x_local.dot(beta) - y) * beta
                    x_local -= eta * grad
                x_locals.append(x_local)
            x = np.mean(x_locals, axis=0)
            err = np.linalg.norm(x - x_target)
            if err <= target_error:
                best_round = min(best_round, r)
                round_limit = r - 1
                break
            if not np.all(np.isfinite(x)):
                # Diverged: every further update keeps x at inf/NaN, so this
                # step-size can never reach the target.
                break
    return best_round

def communication_complexity_experiment(
    d=5, M=50, K=10, R_max=150, sigma_noise=0.1,
    mu0=5.0, R_star=1.0,
    tau_points=20, target_error=0.04,
    step_sizes=None, n_runs=20, seed=10, zeta=1.0):
    """
    Vary second-order heterogeneity (tau) on x-axis, and plot
    average # of communication rounds needed to reach target_error.

    The client optima are sampled once (radius R_star, inside a cap around a
    random central axis whose half-angle is derived from zeta) and kept fixed.
    For every tau, n_runs fresh sets of client covariate means (radius mu0,
    inside a cap around the same axis whose half-angle is derived from tau)
    are sampled and passed to rounds_to_target; the results are averaged.
    Both half-angles use ``2 * arcsin(s / (2 * radius))``, i.e. the angle
    subtended by a chord of length s, with the ratio clipped at 1.

    Parameters
    ----------
    d : int
        Problem dimension.
    M : int
        Number of clients.
    K : int
        Local SGD steps per round.
    R_max : int
        Maximum number of communication rounds per step-size.
    sigma_noise : float
        Standard deviation of the label noise.
    mu0 : float
        Norm of every client covariate mean; must be nonzero.
    R_star : float
        Norm of every client optimum.
    tau_points : int
        Number of evenly spaced tau values between 0 and 2 * mu0.
    target_error : float
        Error threshold passed to rounds_to_target.
    step_sizes : sequence of float or None
        Step-sizes to try; None means ``np.logspace(-3, -1, 5)``.
    n_runs : int
        Independent repetitions per tau value.
    seed : int
        Seed for the global ``np.random`` state, set at the start.
    zeta : float
        Chord length that fixes the cap of the client optima (the amount of
        concept heterogeneity); the default 1.0 is the value that used to be
        hard-coded.

    Returns
    -------
    tuple of numpy.ndarray
        ``(tau_list, rounds_avg)``: the tau values and the average number of
        rounds for each, as plotted.

    Raises
    ------
    ValueError
        If mu0 is zero.
    """
    if mu0 == 0:
        # tau / (2 * mu0) would be NaN, giving a NaN cap angle for which
        # rejection sampling can never accept a candidate.
        raise ValueError("mu0 must be nonzero")

    np.random.seed(seed)
    tau_list = np.linspace(0, 2 * mu0, tau_points)
    step_sizes = np.logspace(-3, -1, 5) if step_sizes is None else step_sizes

    # Fix concept heterogeneity: the optima are drawn once and shared by all tau.
    phi = 2 * np.arcsin(min(zeta / (2 * R_star), 1.0))
    central_axis = np.random.randn(d)
    central_axis /= np.linalg.norm(central_axis)
    x_dirs = sample_spherical_cap(central_axis, phi, M)
    x_star_list = [R_star * v for v in x_dirs]

    rounds_avg = np.zeros(len(tau_list))
    for i, tau in enumerate(tau_list):
        theta = 2 * np.arcsin(min(tau / (2 * mu0), 1.0))
        runs = []
        for _ in range(n_runs):
            # Resample the covariate means for every repetition.
            mu_dirs = sample_spherical_cap(central_axis, theta, M)
            mu_list = [mu0 * u for u in mu_dirs]
            r_needed = rounds_to_target(
                mu_list, x_star_list, sigma_noise, M, K, R_max, step_sizes, target_error
            )
            runs.append(r_needed)
        rounds_avg[i] = np.mean(runs)
        print(f"τ={tau:.2f} → avg rounds={rounds_avg[i]:.2f}")

    # Plot
    plt.figure(figsize=(6,4))
    plt.plot(tau_list, rounds_avg, marker='o')
    plt.xlabel('Covariate shift')
    plt.ylabel('Avg. communication rounds to reach error ≤ ' + str(target_error))
    plt.title('Comm. complexity vs. 2nd-order heterogeneity')
    plt.grid(True)
    plt.show()

    return tau_list, rounds_avg

# Run the communication complexity experiment with its default settings.
# NOTE: this call is at module level, so it also executes whenever this file is imported.
communication_complexity_experiment()
